"""
Runs benchmark evaluations with the BEIR dataset.

Install txtai and the following dependencies to run:
    pip install txtai pytrec_eval rank-bm25 bm25s elasticsearch psutil
"""

import argparse
import csv
import json
import os
import pickle
import sqlite3
import time

import psutil
import yaml

import numpy as np

from bm25s import BM25 as BM25Sparse
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from pytrec_eval import RelevanceEvaluator
from rank_bm25 import BM25Okapi
from tqdm.auto import tqdm

from txtai.embeddings import Embeddings
from txtai.pipeline import LLM, RAG, Similarity, Tokenizer
from txtai.scoring import ScoringFactory


class Index:
    """
    Base index definition. Defines methods to index and search a dataset.
    """

    def __init__(self, path, config, output, refresh):
        """
        Creates a new index.

        Args:
            path: path to dataset
            config: path to config file
            output: path to store index
            refresh: overwrites existing index if True, otherwise existing index is loaded
        """

        self.path, self.config = path, config
        self.output, self.refresh = output, refresh

        # Build and save index
        self.backend = self.index()

    def __call__(self, limit, filterscores=True):
        """
        Main evaluation logic. Loads an index, runs the dataset queries and returns the results.

        Args:
            limit: maximum results
            filterscores: if exact matches should be filtered out

        Returns:
            search results
        """

        uids, queries = self.load()

        # Process queries in batches, tracking the query offset across batches
        results, offset = {}, 0
        for batch in self.batch(queries, 256):
            for i, matches in enumerate(self.search(batch, limit + 1)):
                # Normalize matches to a list of (id, score) tuples
                matches = list(matches)
                if matches and isinstance(matches[0], dict):
                    matches = [(m["id"], m["score"]) for m in matches]

                uid = uids[offset + i]
                if filterscores:
                    # Drop the query document itself, then trim to limit
                    matches = [(mid, score) for mid, score in matches if mid != uid][:limit]

                results[uid] = dict(matches)

            # Increment offset
            offset += len(batch)

        return results

    def search(self, queries, limit):
        """
        Runs a search for a set of queries.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            search results
        """

        return self.backend.batchsearch(queries, limit)

    def index(self):
        """
        Indexes a dataset. Must be implemented by subclasses.
        """

        raise NotImplementedError

    def rows(self):
        """
        Iterates over the dataset yielding a row at a time for indexing.
        """

        # Data file
        datafile = f"{self.path}/corpus.jsonl"

        # First pass counts lines for progress reporting
        with open(datafile, encoding="utf-8") as f:
            total = sum(1 for _ in f)

        # Second pass yields (id, text, tags) rows
        with open(datafile, encoding="utf-8") as f:
            for line in tqdm(f, total=total):
                entry = json.loads(line)

                # Prepend title when present
                text = f'{entry["title"]}. {entry["text"]}' if entry.get("title") else entry["text"]
                if text:
                    yield (entry["_id"], text, None)

    def load(self):
        """
        Loads queries for the dataset. Returns a list of expected result ids and input queries.

        Returns:
            (result ids, input queries)
        """

        with open(f"{self.path}/queries.jsonl", encoding="utf-8") as f:
            entries = [json.loads(line) for line in f]

        return [e["_id"] for e in entries], [e["text"] for e in entries]

    def batch(self, data, size):
        """
        Splits data into equal sized batches.

        Args:
            data: input data
            size: batch size

        Returns:
            data split into equal size batches
        """

        return [data[i : i + size] for i in range(0, len(data), size)]

    def readconfig(self, key, default):
        """
        Reads configuration from a config file. Returns default configuration
        if config file is not found or config key isn't present.

        Args:
            key: configuration key to lookup
            default: default configuration

        Returns:
            config if found, otherwise returns default config
        """

        # Fall back to default when no config file is available
        if not self.config or not os.path.exists(self.config):
            return default

        # Read configuration and check for key
        with open(self.config, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        return config[key] if key in config else default


class Embed(Index):
    """
    Embeddings index using txtai.
    """

    def index(self):
        """
        Loads an existing embeddings index when available, otherwise builds and saves a new one.

        Returns:
            Embeddings
        """

        # Reuse a previously built index unless a refresh was requested
        if os.path.exists(self.output) and not self.refresh:
            embeddings = Embeddings()
            embeddings.load(self.output)
            return embeddings

        # Default configuration, overridden by the config file when present
        defaults = {"batch": 8192, "encodebatch": 128, "faiss": {"quantize": True, "sample": 0.05}}

        # Build and save index
        embeddings = Embeddings(self.readconfig("embeddings", defaults))
        embeddings.index(self.rows())
        embeddings.save(self.output)

        return embeddings


class Hybrid(Index):
    """
    Hybrid embeddings + BM25 index using txtai.
    """

    def index(self):
        """
        Loads an existing hybrid index when available, otherwise builds and saves a new one.

        Returns:
            Embeddings
        """

        # Reuse a previously built index unless a refresh was requested
        if os.path.exists(self.output) and not self.refresh:
            embeddings = Embeddings()
            embeddings.load(self.output)
            return embeddings

        # Default configuration (dense + normalized BM25 scoring), overridden by the config file when present
        defaults = {
            "batch": 8192,
            "encodebatch": 128,
            "faiss": {"quantize": True, "sample": 0.05},
            "scoring": {"method": "bm25", "terms": True, "normalize": True},
        }

        # Build and save index
        embeddings = Embeddings(self.readconfig("hybrid", defaults))
        embeddings.index(self.rows())
        embeddings.save(self.output)

        return embeddings


class RetrievalAugmentedGeneration(Embed):
    """
    Retrieval augmented generation (RAG) using txtai.
    """

    def __init__(self, path, config, output, refresh):
        """
        Builds the embeddings index, then loads a RAG pipeline on top of it.

        Args:
            path: path to dataset
            config: path to config file
            output: path to store index
            refresh: overwrites existing index if True, otherwise existing index is loaded
        """

        # Parent logic builds the embeddings index
        super().__init__(path, config, output, refresh)

        # Read LLM and RAG configuration
        llmconfig, ragconfig = self.readconfig("llm", {}), self.readconfig("rag", {})

        # Load RAG pipeline backed by the embeddings index
        self.rag = RAG(self.backend, LLM(**llmconfig), output="reference", **ragconfig)

    def search(self, queries, limit):
        """
        Runs RAG for each query, returning the cited reference as the single result.

        Args:
            queries: list of queries to run
            limit: maximum results (used as the RAG context window size)

        Returns:
            search results
        """

        # Set context window size to limit and run
        self.rag.context = limit
        return [[(response["reference"], 1)] for response in self.rag(queries, maxlength=4096)]


class Score(Index):
    """
    BM25 index using txtai.
    """

    def index(self):
        """
        Loads an existing scoring index when available, otherwise builds and saves a new one.

        Returns:
            Scoring instance
        """

        # Create scoring instance from configuration
        scoring = ScoringFactory.create(self.readconfig("scoring", {"method": "bm25", "terms": True}))

        # Scoring index stored in a subdirectory of output
        target = os.path.join(self.output, "scoring")
        if os.path.exists(target) and not self.refresh:
            scoring.load(target)
            return scoring

        # Build and save index
        scoring.index(self.rows())
        scoring.save(target)

        return scoring


class Similar(Index):
    """
    Search data using a similarity pipeline.
    """

    def index(self):
        """
        Loads a similarity pipeline and encodes the dataset texts.

        Returns:
            (ids, encoded data, similarity model)
        """

        # Load similarity pipeline
        model = Similarity(**self.readconfig("similar", {}))

        # Read the dataset
        rows = list(self.rows())
        ids = [row[0] for row in rows]

        # Encode texts
        encoded = model.encode([row[1] for row in rows], "data")

        return (ids, encoded, model)

    def search(self, queries, limit):
        """
        Runs similarity model inference for a set of queries.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            search results
        """

        # Unpack backend
        ids, data, model = self.backend

        # Map model results back to dataset ids
        return [[(ids[i], score) for i, score in result] for result in model(queries, data, limit=limit)]


class Rerank(Embed):
    """
    Embeddings index using txtai combined with a similarity pipeline
    """

    def index(self):
        """
        Builds an embeddings index and combines it with a similarity pipeline backend.

        Returns:
            (ids, encoded data, similarity model, embeddings)
        """

        # Build embeddings index
        embeddings = super().index()

        # Similar.__init__ already builds its backend - reuse it rather than
        # calling model.index() again, which would encode the dataset a second time
        model = Similar(self.path, self.config, self.output, self.refresh)
        return model.backend + (embeddings,)

    def search(self, queries, limit):
        """
        Runs an initial embeddings query, then reranks the candidates with the similarity model.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            search results, one entry per query
        """

        # Unpack backend
        ids, data, model, embeddings = self.backend

        # Build id -> position lookup once (setdefault keeps the first occurrence,
        # matching the semantics of list.index) instead of an O(n) scan per result
        positions = {}
        for position, uid in enumerate(ids):
            positions.setdefault(uid, position)

        # Run initial query, retrieving extra candidates for reranking
        indices = []
        for r in embeddings.batchsearch(queries, limit * 10):
            indices.append({x: positions[uid] for x, (uid, _) in enumerate(r)})

        # Run model inference to rerank candidates
        results = []
        for x, query in enumerate(queries):
            queue = data[list(indices[x].values())]
            if len(queue) > 0:
                result = model(query, queue, limit=limit)
                results.append([(ids[indices[x][i]], score) for i, score in result])
            else:
                # Keep results aligned with queries when the initial query has no candidates
                results.append([])

        return results


class RankBM25(Index):
    """
    BM25 index using rank-bm25.
    """

    def search(self, queries, limit):
        """
        Scores each query against the full corpus and returns the top matches.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            search results
        """

        ids, backend = self.backend
        tokenizer, results = Tokenizer(), []
        for query in queries:
            # Score query against every document, take top scores in descending order
            scores = backend.get_scores(tokenizer(query))
            topn = np.argsort(scores)[::-1][:limit]
            results.append([(ids[x], scores[x]) for x in topn])

        return results

    def index(self):
        """
        Loads an existing rank-bm25 model when available, otherwise builds and saves a new one.

        Returns:
            (ids, BM25Okapi model)
        """

        output = os.path.join(self.output, "rank")
        if os.path.exists(output) and not self.refresh:
            # NOTE: pickle is only safe here because the file is a locally-generated cache
            with open(output, "rb") as f:
                ids, model = pickle.load(f)
        else:
            # Tokenize data
            tokenizer, data = Tokenizer(), []
            for uid, text, _ in self.rows():
                data.append((uid, tokenizer(text)))

            ids = [uid for uid, _ in data]
            model = BM25Okapi([text for _, text in data])

            # Save ids together with the model - the load branch above unpacks an
            # (ids, model) tuple, so saving the model alone would break reloads
            with open(output, "wb") as out:
                pickle.dump((ids, model), out)

        return ids, model


class BM25S(Index):
    """
    BM25 as implemented by bm25s
    """

    def __init__(self, path, config, output, refresh):
        """
        Creates a new bm25s index.

        Args:
            path: path to dataset
            config: path to config file
            output: path to store index
            refresh: overwrites existing index if True, otherwise existing index is loaded
        """

        # Corpus ids, populated during indexing
        self.ids = None

        # Parent logic
        super().__init__(path, config, output, refresh)

    def search(self, queries, limit):
        """
        Runs bm25s retrieval for a set of queries.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            search results
        """

        tokenizer = Tokenizer()
        matches, scores = self.backend.retrieve([tokenizer(query) for query in queries], corpus=self.ids, k=limit)

        # List of queries => list of matches (id, score)
        return [
            [(str(uid), float(score)) for uid, score in zip(row, rowscores)]
            for row, rowscores in zip(matches, scores)
        ]

    def index(self):
        """
        Loads an existing bm25s model when available, otherwise builds and saves a new one.

        Returns:
            bm25s model
        """

        tokenizer, ids, texts = Tokenizer(), [], []

        # Read the dataset
        for uid, text, _ in self.rows():
            ids.append(uid)
            texts.append(text)

        # Store corpus ids for searches
        self.ids = ids

        # Reuse a previously built model unless a refresh was requested
        if os.path.exists(self.output) and not self.refresh:
            return BM25Sparse.load(self.output)

        # Build and save model
        model = BM25Sparse(method="lucene", k1=1.2, b=0.75)
        model.index([tokenizer(text) for text in texts], leave_progress=False)
        model.save(self.output)

        return model


class SQLiteFTS(Index):
    """
    BM25 index using SQLite's FTS extension.
    """

    def search(self, queries, limit):
        """
        Runs FTS5 MATCH queries, scoring results with the built-in bm25() function.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            search results
        """

        tokenizer, results = Tokenizer(), []
        for query in queries:
            # Tokenize query and match any term
            query = tokenizer(query)
            query = " OR ".join([f'"{q}"' for q in query])

            # bm25() returns lower-is-better values, negate so larger scores mean better matches.
            # Bind both the match expression and the limit as parameters instead of
            # interpolating values into the SQL string.
            self.backend.execute(
                "SELECT id, bm25(textindex) * -1 score FROM textindex WHERE text MATCH ? ORDER BY bm25(textindex) LIMIT ?", [query, limit]
            )

            results.append(list(self.backend))

        return results

    def index(self):
        """
        Loads an existing SQLite FTS5 database when available, otherwise builds a new one.

        Returns:
            database cursor
        """

        if os.path.exists(self.output) and not self.refresh:
            # Load existing database
            connection = sqlite3.connect(self.output)
        else:
            # Delete existing database
            if os.path.exists(self.output):
                os.remove(self.output)

            # Create new database
            connection = sqlite3.connect(self.output)

            # Tokenize data
            tokenizer, data = Tokenizer(), []
            for uid, text, _ in self.rows():
                data.append((uid, " ".join(tokenizer(text))))

            # Create table
            connection.execute("CREATE VIRTUAL TABLE textindex using fts5(id, text)")

            # Load data and build index
            connection.executemany("INSERT INTO textindex VALUES (?, ?)", data)

            connection.commit()

        return connection.cursor()


class Elastic(Index):
    """
    BM25 index using Elasticsearch.
    """

    def search(self, queries, limit):
        """
        Runs a multi-search request against Elasticsearch.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            search results
        """

        # Generate bulk queries
        request = []
        for query in queries:
            req_head = {"index": "textindex", "search_type": "dfs_query_then_fetch"}
            req_body = {
                "_source": False,
                "query": {"multi_match": {"query": query, "type": "best_fields", "fields": ["text"], "tie_breaker": 0.5}},
                "size": limit,
            }
            request.extend([req_head, req_body])

        # Run ES query
        response = self.backend.msearch(body=request, request_timeout=600)

        # Read responses
        results = []
        for resp in response["responses"]:
            result = resp["hits"]["hits"]
            results.append([(r["_id"], r["_score"]) for r in result])

        return results

    def index(self):
        """
        Rebuilds the Elasticsearch index from the dataset.

        Returns:
            Elasticsearch client
        """

        es = Elasticsearch("http://localhost:9200")

        # Delete existing index - narrowed from a bare except so KeyboardInterrupt and
        # SystemExit are not swallowed; any client error (e.g. index not found) is ignored
        try:
            es.indices.delete(index="textindex")
        except Exception:
            pass

        # Bulk load data and make it searchable
        bulk(es, ({"_index": "textindex", "_id": uid, "text": text} for uid, text, _ in self.rows()))
        es.indices.refresh(index="textindex")

        return es


def relevance(path):
    """
    Loads relevance data for evaluation.

    Args:
        path: path to dataset test file

    Returns:
        relevance data
    """

    rel = {}
    with open(f"{path}/qrels/test.tsv", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_MINIMAL)

        # Skip header row
        next(reader)

        # Row format: query id, corpus id, relevance score
        for row in reader:
            rel.setdefault(row[0], {})[row[1]] = int(row[2])

    return rel


def create(method, path, config, output, refresh):
    """
    Creates a new index.

    Args:
        method: indexing method
        path: path to dataset
        config: path to config file
        output: path to store index
        refresh: overwrites existing index if True, otherwise existing index is loaded

    Returns:
        Index
    """

    # Method name to index type, Embed is the default
    indexes = {
        "hybrid": Hybrid,
        "rag": RetrievalAugmentedGeneration,
        "scoring": Score,
        "rank": RankBM25,
        "bm25s": BM25S,
        "sqlite": SQLiteFTS,
        "es": Elastic,
        "similar": Similar,
        "rerank": Rerank,
    }

    return indexes.get(method, Embed)(path, config, output, refresh)


def compute(results):
    """
    Computes metrics using the results from an evaluation run.

    Args:
        results: evaluation results, {query id: {metric: value}}

    Returns:
        metrics averaged across all queries, rounded to 5 digits
    """

    # Group metric values across queries
    metrics = {}
    for result in results.values():
        for metric, value in result.items():
            metrics.setdefault(metric, []).append(value)

    # Average each metric across queries
    return {metric: round(np.mean(values), 5) for metric, values in metrics.items()}


def evaluate(methods, path, args):
    """
    Runs an evaluation.

    Args:
        methods: list of indexing methods to test
        path: path to dataset
        args: command line arguments

    Returns:
        {method: performance metrics} for each method evaluated
    """

    print(f"------ {os.path.basename(path)} ------")

    # Performance stats
    performance = {}

    # Calculate stats for each model type
    topk = args.topk

    # pytrec_eval takes measure names with "." separators and reports them with "_" separators
    evaluator = RelevanceEvaluator(relevance(path), {f"ndcg_cut.{topk}", f"map_cut.{topk}", f"recall.{topk}", f"P.{topk}"})
    for method in methods:
        # Stats for this source
        stats = {}
        performance[method] = stats

        # Create index and get results
        start = time.time()
        output = args.output if args.output else f"{path}/{method}"
        index = create(method, path, args.config, output, args.refresh)

        # Add indexing metrics: build time (s), process RSS (MB) and index size on disk (KB)
        stats["index"] = round(time.time() - start, 2)
        stats["memory"] = int(psutil.Process().memory_info().rss / (1024 * 1024))
        # NOTE: disk usage only sums top-level files in the output directory, not subdirectories
        stats["disk"] = int(sum(d.stat().st_size for d in os.scandir(output) if d.is_file()) / 1024) if os.path.isdir(output) else 0

        print("INDEX TIME =", time.time() - start)
        print(f"MEMORY USAGE = {stats['memory']} MB")
        print(f"DISK USAGE = {stats['disk']} KB")

        # Run the dataset queries against the index
        start = time.time()
        results = index(topk)

        # Add search metrics
        stats["search"] = round(time.time() - start, 2)
        print("SEARCH TIME =", time.time() - start)

        # Calculate stats
        metrics = compute(evaluator.evaluate(results))

        # Add accuracy metrics (keys use "_" separators, see evaluator note above)
        for stat in [f"ndcg_cut_{topk}", f"map_cut_{topk}", f"recall_{topk}", f"P_{topk}"]:
            stats[stat] = metrics[stat]

        # Print model stats
        print(f"------ {method} ------")
        print(f"NDCG@{topk} =", metrics[f"ndcg_cut_{topk}"])
        print(f"MAP@{topk} =", metrics[f"map_cut_{topk}"])
        print(f"Recall@{topk} =", metrics[f"recall_{topk}"])
        print(f"P@{topk} =", metrics[f"P_{topk}"])

    print()
    return performance


def benchmarks(args):
    """
    Main benchmark execution method.

    Args:
        args: command line arguments
    """

    # Directory where BEIR datasets are stored
    directory = args.directory if args.directory else "beir"

    if args.sources and args.methods:
        # Explicit sources/methods - append to the existing benchmarks file
        sources, methods, mode = args.sources.split(","), args.methods.split(","), "a"
    else:
        # Default sources and methods - overwrite the benchmarks file
        sources = [
            "trec-covid",
            "nfcorpus",
            "nq",
            "hotpotqa",
            "fiqa",
            "arguana",
            "webis-touche2020",
            "quora",
            "dbpedia-entity",
            "scidocs",
            "fever",
            "climate-fever",
            "scifact",
        ]
        methods = ["embed", "hybrid", "rag", "scoring", "rank", "bm25s", "sqlite", "es", "similar", "rerank"]
        mode = "w"

    # Run benchmarks and save results as JSON lines
    with open("benchmarks.json", mode, encoding="utf-8") as f:
        for source in sources:
            # Run evaluations for this source and write a record per method
            for method, stats in evaluate(methods, f"{directory}/{source}", args).items():
                stats["source"] = source
                stats["method"] = method
                stats["name"] = args.name if args.name else method

                json.dump(stats, f)
                f.write("\n")


if __name__ == "__main__":
    # Command line parser
    cli = argparse.ArgumentParser(description="Benchmarks")
    cli.add_argument("-c", "--config", help="path to config file", metavar="CONFIG")
    cli.add_argument("-d", "--directory", help="root directory path with datasets", metavar="DIRECTORY")
    cli.add_argument("-m", "--methods", help="comma separated list of methods", metavar="METHODS")
    cli.add_argument("-n", "--name", help="name to assign to this run, defaults to method name", metavar="NAME")
    cli.add_argument("-o", "--output", help="index output directory path", metavar="OUTPUT")
    cli.add_argument("-r", "--refresh", help="refreshes index if set, otherwise uses existing index if available", action="store_true")
    cli.add_argument("-s", "--sources", help="comma separated list of data sources", metavar="SOURCES")
    cli.add_argument("-t", "--topk", help="top k results to use for the evaluation", metavar="TOPK", type=int, default=10)

    # Calculate benchmarks
    benchmarks(cli.parse_args())
